{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from PT_HGNN.data import *\n",
    "from PT_HGNN.model import *\n",
    "from warnings import filterwarnings\n",
    "filterwarnings(\"ignore\")\n",
    "\n",
    "import numpy as np\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): renamed_load presumably unpickles this file; unpickling can execute\n",
    "# arbitrary code, so only load graph dumps that come from a trusted source.\n",
    "with open(\"dataset/graph_Engineering.pk\", 'rb') as graph_file:\n",
    "    # The context manager closes the file handle as soon as the graph is loaded.\n",
    "    graph = renamed_load(graph_file)\n",
    "\n",
    "# train_range = {t: True for t in graph.times if t != None and t < 2015}\n",
    "# valid_range = {t: True for t in graph.times if t != None and t >= 2015  and t <= 2016}\n",
    "# test_range  = {t: True for t in graph.times if t != None and t > 2016}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Quick look at the loaded graph. A notebook cell only renders its last expression,\n",
    "# so just `types` is shown; the results of the first two statements are discarded.\n",
    "graph.node_feature['paper'].keys()\n",
    "graph.get_meta_graph()\n",
    "types = graph.get_meta_graph()\n",
    "types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neighbour lookup, keyed level by level:\n",
    "#   target_type -> source_type -> relation_type -> target_id -> [source_id, ...]\n",
    "# Every level is created on first access; the innermost default is an empty list.\n",
    "targetid2sourceidlist = defaultdict(\n",
    "    lambda: defaultdict(\n",
    "        lambda: defaultdict(\n",
    "            lambda: defaultdict(list))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Degree lookup, keyed level by level:\n",
    "#   target_type -> source_type -> relation_type -> target_id -> degree\n",
    "# Every level is created on first access; a node that was never written reads as 0.\n",
    "targetid2degree = defaultdict(\n",
    "    lambda: defaultdict(\n",
    "        lambda: defaultdict(\n",
    "            lambda: defaultdict(int))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise, for every (target_type, source_type, relation_type), the neighbour\n",
    "# list and the degree of each target node.\n",
    "for target_type, by_source in graph.edge_list.items():\n",
    "    for source_type, by_relation in by_source.items():\n",
    "        for relation_type, by_target in by_relation.items():\n",
    "            for target_id, sources in by_target.items():\n",
    "                # Assign a fresh list instead of appending: re-running this cell must not\n",
    "                # duplicate neighbours and inflate the degrees.\n",
    "                neighbours = list(sources)\n",
    "                targetid2sourceidlist[target_type][source_type][relation_type][target_id] = neighbours\n",
    "                targetid2degree[target_type][source_type][relation_type][target_id] = len(neighbours)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_ppr_node(inode, target_list, target_degree, source_list, source_degree, alpha, epsilon):\n",
    "    \"\"\"Score the source-side nodes reachable from ``inode`` with a residual-push procedure.\n",
    "\n",
    "    The relation is walked as a two-sided graph. A residual of ``alpha`` starts on the\n",
    "    target-side node ``inode``. Processing a node forwards\n",
    "    ``(1 - alpha) * residual / degree`` to each neighbour on the other side and resets\n",
    "    its own residual; a neighbour is queued once its residual reaches\n",
    "    ``alpha * epsilon * degree``. Only source-side nodes add their residual to the\n",
    "    returned scores.\n",
    "\n",
    "    Args:\n",
    "        inode: target-side start node.\n",
    "        target_list: mapping target node -> list of source-side neighbours.\n",
    "        target_degree: mapping target node -> degree used to split its residual.\n",
    "        source_list: mapping source node -> list of target-side neighbours.\n",
    "        source_degree: mapping source node -> degree used to split its residual.\n",
    "        alpha: initial residual; ``1 - alpha`` is the fraction forwarded by a push.\n",
    "        epsilon: tolerance factor of the queueing threshold.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping every processed source-side node to its accumulated score.\n",
    "\n",
    "    A node missing from a neighbour mapping has no neighbours, and a node missing from\n",
    "    a degree mapping has threshold degree 0. These lookups use ``.get``, so defaultdict\n",
    "    arguments are not grown by the call. A node that does have neighbours must have a\n",
    "    non-zero degree entry, because the residual is divided by it.\n",
    "    \"\"\"\n",
    "    alpha_eps = alpha * epsilon\n",
    "\n",
    "    p = {}\n",
    "    r_t = {inode: alpha}\n",
    "    r_s = {}\n",
    "    # The lists are LIFO work stacks; the sets mirror their content so the\n",
    "    # \"already queued?\" test is O(1) instead of a scan of the list.\n",
    "    q_t, queued_t = [inode], {inode}\n",
    "    q_s, queued_s = [], set()\n",
    "\n",
    "    t2s = True\n",
    "    while q_t or q_s:\n",
    "        # Keep draining the current side; switch direction only when one side is empty.\n",
    "        if not q_t:\n",
    "            t2s = False\n",
    "        elif not q_s:\n",
    "            t2s = True\n",
    "\n",
    "        if t2s:\n",
    "            unode = q_t.pop()\n",
    "            queued_t.discard(unode)\n",
    "\n",
    "            # Target-side residual is forwarded but not added to p.\n",
    "            res = r_t.get(unode, 0.0)\n",
    "            r_t[unode] = 0.0\n",
    "\n",
    "            neighbours = target_list.get(unode, ())\n",
    "            if neighbours:\n",
    "                # Every neighbour receives the same share, so compute it once.\n",
    "                _val = (1 - alpha) * res / target_degree[unode]\n",
    "            for vnode in neighbours:\n",
    "                r_s[vnode] = r_s.get(vnode, 0.0) + _val\n",
    "                if r_s[vnode] >= alpha_eps * source_degree.get(vnode, 0) and vnode not in queued_s:\n",
    "                    q_s.append(vnode)\n",
    "                    queued_s.add(vnode)\n",
    "        else:\n",
    "            unode = q_s.pop()\n",
    "            queued_s.discard(unode)\n",
    "\n",
    "            # Source-side residual is what ends up in the returned scores.\n",
    "            res = r_s.get(unode, 0.0)\n",
    "            p[unode] = p.get(unode, 0.0) + res\n",
    "            r_s[unode] = 0.0\n",
    "\n",
    "            neighbours = source_list.get(unode, ())\n",
    "            if neighbours:\n",
    "                _val = (1 - alpha) * res / source_degree[unode]\n",
    "            for vnode in neighbours:\n",
    "                r_t[vnode] = r_t.get(vnode, 0.0) + _val\n",
    "                if r_t[vnode] >= alpha_eps * target_degree.get(vnode, 0) and vnode not in queued_t:\n",
    "                    q_t.append(vnode)\n",
    "                    queued_t.add(vnode)\n",
    "\n",
    "    return p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PPR weight of every edge, keyed level by level:\n",
    "#   target_type -> source_type -> relation_type -> target_id -> source_id -> weight\n",
    "# The leaf factory is `float` itself, so an edge that was never scored reads as 0.0\n",
    "# (a `lambda: float` leaf would hand back the `float` class instead of a number).\n",
    "edge_ppr_list = defaultdict(\n",
    "    lambda: defaultdict(\n",
    "        lambda: defaultdict(\n",
    "            lambda: defaultdict(\n",
    "                lambda: defaultdict(float)))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### cs alpha = 0.1 epsilon = 0.001\n",
    "\n",
    "### material science  alpha = 0.1 epsilon = 0.0001 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _reverse_relation_name(relation_type, reverse_lists):\n",
    "    \"\"\"Pick the key in ``reverse_lists`` that holds the reverse of ``relation_type``.\"\"\"\n",
    "    candidate = \"rev_\" + relation_type\n",
    "    if not reverse_lists.get(candidate) and relation_type.startswith(\"rev_\"):\n",
    "        stripped = relation_type[len(\"rev_\"):]\n",
    "        if reverse_lists.get(stripped):\n",
    "            # relation_type is itself a reverse relation: its reverse is the plain\n",
    "            # relation, not an empty \"rev_rev_...\" entry.\n",
    "            return stripped\n",
    "    return candidate\n",
    "\n",
    "\n",
    "def sum_ppr(alpha=0.1, epsilon=0.001):\n",
    "    \"\"\"Store a PPR weight for every edge of ``graph.edge_list`` in ``edge_ppr_list``.\n",
    "\n",
    "    For each (target_type, source_type, relation_type) and each target node,\n",
    "    ``calc_ppr_node`` is run from that node and the score of every direct neighbour\n",
    "    is written to\n",
    "    ``edge_ppr_list[target_type][source_type][relation_type][target_id][source_id]``\n",
    "    (0.0 when the neighbour received no score).\n",
    "\n",
    "    Reads the notebook globals ``graph``, ``targetid2sourceidlist`` and\n",
    "    ``targetid2degree``; writes the global ``edge_ppr_list``.\n",
    "\n",
    "    Args:\n",
    "        alpha: passed to ``calc_ppr_node``; the default is the value that used to be\n",
    "            hard-coded in this function.\n",
    "        epsilon: passed to ``calc_ppr_node``; the default is the value that used to be\n",
    "            hard-coded in this function.\n",
    "    \"\"\"\n",
    "    for target_type, by_source in graph.edge_list.items():\n",
    "        for source_type, by_relation in by_source.items():\n",
    "            for relation_type, tesr in by_relation.items():\n",
    "                # The walk back from the source side needs the reverse adjacency.\n",
    "                rev_relation = _reverse_relation_name(\n",
    "                    relation_type, targetid2sourceidlist[source_type][target_type])\n",
    "                target_list = targetid2sourceidlist[target_type][source_type][relation_type]\n",
    "                target_degree = targetid2degree[target_type][source_type][relation_type]\n",
    "                source_list = targetid2sourceidlist[source_type][target_type][rev_relation]\n",
    "                source_degree = targetid2degree[source_type][target_type][rev_relation]\n",
    "                for target_id in tesr:\n",
    "                    ppr_val = calc_ppr_node(target_id, target_list, target_degree,\n",
    "                                            source_list, source_degree, alpha, epsilon)\n",
    "                    for source_id in target_list[target_id]:\n",
    "                        edge_ppr_list[target_type][source_type][relation_type][target_id][source_id] = ppr_val.get(source_id, 0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Wall-clock the full PPR pass and print the elapsed seconds.\n",
    "t1 = time.time()\n",
    "sum_ppr()\n",
    "elapsed_seconds = time.time() - t1\n",
    "print(elapsed_seconds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edges kept after filtering, keyed level by level:\n",
    "#   target_type -> source_type -> relation_type -> target_id -> source_id -> time\n",
    "# The leaf factory is `int` itself, so a missing edge reads as 0\n",
    "# (a `lambda: int` leaf would hand back the `int` class instead of a number).\n",
    "filter_edge_list = defaultdict(\n",
    "    lambda: defaultdict(\n",
    "        lambda: defaultdict(\n",
    "            lambda: defaultdict(\n",
    "                lambda: defaultdict(int)))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug aid: list the attribute and method names available on the loaded graph object.\n",
    "print(dir(graph))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Display graph.times; the next cell iterates over it to build pre_range.\n",
    "graph.times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time stamps strictly below this value make up pre_range.\n",
    "PRE_RANGE_END = 2014\n",
    "\n",
    "# Dict used as a set, so later `x in pre_range` checks are constant-time.\n",
    "# `is not None` is the idiomatic identity test for a missing time stamp.\n",
    "pre_range   = {t: True for t in graph.times if t is not None and t < PRE_RANGE_END}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions of the papers whose time stamp falls inside pre_range.\n",
    "pre_target_nodes = [\n",
    "    p_id\n",
    "    for p_id, _time in enumerate(list(graph.node_feature['paper']['time']))\n",
    "    if _time in pre_range\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per target paper, keep at most this many paper-paper edges dated inside pre_range.\n",
    "MAX_PRE_RANGE_EDGES = 8\n",
    "\n",
    "for target_type, by_source in graph.edge_list.items():\n",
    "    for source_type, by_relation in by_source.items():\n",
    "        for relation_type, tesr in by_relation.items():\n",
    "            if target_type != 'paper' or source_type != 'paper':\n",
    "                # Only paper-paper relations are pruned; any other relation is reused\n",
    "                # as-is (same object, not a copy).\n",
    "                filter_edge_list[target_type][source_type][relation_type] = tesr\n",
    "                continue\n",
    "\n",
    "            for target_id, tesrt in tesr.items():\n",
    "                if not tesrt:\n",
    "                    # No edges: do not create an empty entry for this target.\n",
    "                    continue\n",
    "                kept = filter_edge_list[target_type][source_type][relation_type][target_id]\n",
    "\n",
    "                # Edges dated outside pre_range are always kept; the others become\n",
    "                # candidates, remembered together with their PPR weight.\n",
    "                sid_w = {}\n",
    "                for sid, edge_time in tesrt.items():\n",
    "                    if edge_time in pre_range:\n",
    "                        sid_w[sid] = edge_ppr_list[target_type][source_type][relation_type][target_id][sid]\n",
    "                    else:\n",
    "                        kept[sid] = edge_time\n",
    "\n",
    "                if len(sid_w) > MAX_PRE_RANGE_EDGES:\n",
    "                    # Too many candidates: keep only the highest-weighted ones\n",
    "                    # (sorted() is stable, so ties keep their original order).\n",
    "                    ranked = sorted(sid_w.items(), key=lambda item: item[1], reverse=True)\n",
    "                    chosen = [sid for sid, weight in ranked[:MAX_PRE_RANGE_EDGES]]\n",
    "                else:\n",
    "                    chosen = list(sid_w)\n",
    "                for sid in chosen:\n",
    "                    kept[sid] = tesrt[sid]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From here on the graph uses the filtered edges; the attribute is rebound, not copied.\n",
    "graph.edge_list = filter_edge_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dill\n",
    "\n",
    "# Write through a context manager so the file is flushed and closed even if dumping fails.\n",
    "with open('dataset/graph_Engineering_ppr.pk', 'wb') as out_file:\n",
    "    dill.dump(graph, out_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpt-gnn",
   "language": "python",
   "name": "gpt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}